
Cast GroupNorm fp16 backward gradient accumulators to fp16 (not bfloat16) #1253

Open: lollinng wants to merge 1 commit into linkedin:main from lollinng:fix/groupnorm-fp16-backward-dtype
lollinng commented Jun 4, 2026

Problem

group_norm_backward picks the dtype its dW/dB gradient accumulators are cast to before the atomic_add into the DW/DB buffers:

triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16

The DW/DB buffers are allocated with W.dtype / B.dtype. For an fp16 model that is an fp16 buffer, but triton_dtype is bfloat16, so the fp16 gradients are rounded to bfloat16 (7 explicit mantissa bits) and then atomic-added into an fp16 buffer (10 explicit mantissa bits). That discards three bits of precision the buffer could have held. bf16 and fp16 are not interchangeable: bf16 trades mantissa bits for exponent range, fp16 does the opposite.
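To make the rounding loss concrete, here is a minimal sketch that emulates both formats with the Python standard library only. The helper names `round_to_bf16` and `round_to_fp16` are hypothetical and are not part of the kernel; the sample value 4.981 is just the reference gradient quoted in this PR. It illustrates per-value rounding only and does not model the atomic accumulation itself.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Emulate bfloat16 rounding: keep the top 16 bits of the float32
    encoding, rounding to nearest with ties to even. NaN is not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that dropping the low 16 bits rounds instead of truncating.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def round_to_fp16(x: float) -> float:
    """Round to the nearest IEEE half-precision value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


value = 4.981

direct = round_to_fp16(value)                    # cast straight to fp16
via_bf16 = round_to_fp16(round_to_bf16(value))   # cast to bf16, then store as fp16

print(direct, abs(direct - value))      # 4.98046875, error about 0.0005
print(via_bf16, abs(via_bf16 - value))  # 4.96875, error about 0.012
```

In the range [4, 8) adjacent fp16 values are 2^-8 apart while adjacent bf16 values are 2^-5 apart, so routing a value through bf16 can leave it up to eight times further from the true value than a direct fp16 cast.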

Fix

Map fp16 -> tl.float16 so the accumulator/atomic dtype matches the buffer dtype. bf16 inputs are unchanged.
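The mapping described above can be sketched as a small lookup. This is an illustration only, not the actual diff: dtypes are represented as plain strings so the snippet runs without torch or Triton, and the function name `select_accumulator_dtype` is hypothetical. The fallback to bfloat16 for any other dtype is an assumption that mirrors the `else` branch of the original expression.

```python
# Hypothetical stand-in for the dtype selection in group_norm_backward.
# Strings replace torch.* / tl.* dtypes so the sketch has no dependencies.
_ACCUMULATOR_DTYPE = {
    "float32": "float32",
    "float16": "float16",   # the change: previously this fell through to bfloat16
    "bfloat16": "bfloat16",
}


def select_accumulator_dtype(input_dtype: str) -> str:
    """Return the dtype the dW/dB partial sums are cast to before atomic_add.

    Unknown dtypes fall back to bfloat16 (assumed, to match the old default).
    """
    return _ACCUMULATOR_DTYPE.get(input_dtype, "bfloat16")


for name in ("float32", "float16", "bfloat16"):
    # After the fix, each supported input dtype maps to itself, so the
    # value handed to atomic_add has the same dtype as the DW/DB buffer.
    print(name, "->", select_accumulator_dtype(name))
```

The point of the design is that the selected dtype equals the buffer dtype for all three supported inputs, so no implicit conversion is left to the atomic_add.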

Testing done (NVIDIA T4)

GroupNorm forward+backward vs torch.nn.functional.group_norm, comparing the weight gradient dW[0] (ref = 4.981):

                 fp32            fp16
before (main)    dW=4.981        dW=4.996   (bf16-rounded accumulator)
after  (this PR) dW=4.981        dW=4.984   (closer to the fp32 reference)

Both fp16 results pass the tolerance check; the fix moves the fp16 gradient measurably closer to the reference, cutting the absolute error on this element from 0.015 (4.996 vs 4.981) to 0.003 (4.984 vs 4.981). Honesty note: on the Triton version tested, the old bf16 -> fp16 atomic_add is silently coerced rather than erroring, so this is a precision/type-consistency fix, not a crash fix.

group_norm_backward set triton_dtype = bfloat16 for every non-fp32 input.
For fp16 inputs that means the dW/dB gradient accumulators were rounded to
bfloat16 (7 explicit mantissa bits) before being atomic-added into the fp16
DW/DB buffers (10 explicit mantissa bits) -- losing precision for no reason,
since the buffers are fp16. Map fp16 -> fp16 so the atomic_add dtype matches
the buffer dtype.

This is a type-consistency / precision fix, not a crash fix: on the Triton
version tested the bf16->fp16 atomic_add is silently coerced rather than
erroring.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>